# Project README

## Overview

This repository supports inference across **multiple datasets** with **all baseline methods** and **our proposed KITE method**.  
All evaluations are based on accuracy and can be executed using simple shell scripts.

---

## Available Datasets

The following datasets are supported (task names used in scripts):

- `cmsqa`
- `mnli`
- `mrpc`
- `qnli`
- `sst5`
- `swag`

| Type          | Dataset       | Task                      | #Train   | #Validation | #ICE |
|---------------|---------------|---------------------------|----------|-------------|------|
| Classification| SST-5         | Sentiment Analysis        | 8,534    | 1,101       | 40   |
|               | MRPC          | Paraphrase Detection      | 3,668    | 408         | 27   |
|               | MNLI          | Natural Language Inference| 392,568  | 19,647      | 40   |
|               | QNLI          | Natural Language Inference| 104,707  | 5,463       | 27   |
|               | CMSQA         | Commonsense Reasoning     | 9,740    | 1,221       | 50   |
|               | HellaSwag     | Commonsense Reasoning     | 52,611   | 20,006      | 50   |


Refer to **Table 7** in [this paper](https://arxiv.org/pdf/2302.05698) for more details on each dataset.
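
Because the scripts expect one of the exact task names listed above, a small guard can catch typos before a long run starts. The following is only a sketch: `is_valid_task` is a hypothetical helper, not part of this repository, and it assumes the six task names above are the complete list.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the repository): succeeds only for
# the task names listed in this README.
is_valid_task() {
  case "$1" in
    cmsqa|mnli|mrpc|qnli|sst5|swag) return 0 ;;
    *) return 1 ;;
  esac
}

task_name="sst5"
if is_valid_task "$task_name"; then
  echo "task ok: ${task_name}"
else
  echo "unknown task: ${task_name}" >&2
fi
```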

---

## Note
- To evaluate a dataset in one go, put its name in `run_for_dataset`; this reports the scores for the baselines and our method.

## Baselines

### 🔹 Baseline 1: No In-Context Examples (Zero-shot)

1. Open the `run_inference.sh` script.
2. Set the following variables:
   - `task_name`: Choose from the list above.
   - `model_path`: Path to an LLM or saved model checkpoint.
   - `number_of_ice_examples=0`
   - `retrieve_file=""` (empty string)
3. Run inference:
   ```bash
   bash run_inference.sh
   ```
4. Accuracy will be printed in the logs.
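
As an illustration of step 2, the variable block in `run_inference.sh` might look like the sketch below for a zero-shot run. The variable names come from the steps above; the values for `task_name` and `model_path` are placeholders borrowed from the example configuration later in this README, and the actual script may define further variables.

```bash
# Sketch of the zero-shot settings; the values are illustrative assumptions.
task_name="sst5"              # any task name from the list above
model_path="gpt-neo-2.7B"     # placeholder: path to an LLM or saved checkpoint
number_of_ice_examples=0      # zero-shot: no in-context examples
retrieve_file=""              # empty string: no retriever output is used
```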

---

### 🔹 Baseline 2: BM25 Retriever

1. Open `run_bm25.sh` and set the `task_name`.
2. Run:
   ```bash
   bash run_bm25.sh
   ```
   This saves the retriever output to `output/<task_name>/`.
3. To evaluate:
   - Set `retrieve_file` in `run_inference.sh` to the generated file path.
   - Run:
     ```bash
     bash run_inference.sh
     ```
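
This README does not state the name of the file written to `output/<task_name>/`, so one way to find the path for `retrieve_file` is to pick the most recently modified file in that directory. The sketch below assumes the retriever writes a single regular file directly into that directory; `latest_file` is a hypothetical helper, not part of this repository.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the repository): print the most
# recently modified regular file in a directory, or nothing if none exists.
latest_file() {
  local dir="$1" newest="" f
  for f in "$dir"/*; do
    [ -f "$f" ] || continue
    if [ -z "$newest" ] || [ "$f" -nt "$newest" ]; then
      newest="$f"
    fi
  done
  if [ -n "$newest" ]; then
    printf '%s\n' "$newest"
  fi
}

task_name="sst5"
retrieve_file="$(latest_file "output/${task_name}")"
echo "retrieve_file=${retrieve_file}"
```

The printed path can then be copied into `retrieve_file` in `run_inference.sh`.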

---

### 🔹 Baseline 3: Dense Retriever

Same steps as BM25, but using the `run_dense.sh` script.

1. Set `task_name` in `run_dense.sh`.
2. Run:
   ```bash
   bash run_dense.sh
   ```
3. Use the generated retriever file with `run_inference.sh`, just like in Baseline 2.

---

## 🌟 Our Method: KITE

1. Open and configure `run_submodular.sh`:
   - `model_name`: Model used to generate embeddings.
   - `lambda_val`: Weight for the submodular objective.
   - `number_of_ice_examples`: Number of in-context examples to select.
   - `run_for_n_samples`: (Optional) Limits the number of samples, which is useful for debugging. Remove it to run on the full dataset.
2. Run:
   ```bash
   bash run_submodular.sh
   ```
3. The retriever output will be saved in `output/<task_name>/`.
4. Point `retrieve_file` in `run_inference.sh` to this output file and run:
   ```bash
   bash run_inference.sh
   ```

---


## Running Parameters 

Customize the behavior of retrievers, inference, and dataset subsets using the following parameters. These are typically set in your shell scripts before execution.

#### Example Configuration

```bash
task_name="sst5"
number_of_ice_examples=50 
lambda_val=1                       # for submodular retriever
model_name="bert-base-uncased"     # sentence encoder model name
pred_dir="output/${task_name}/results/"

## LLM configs 
llm_model_path="gpt-neo-2.7B"                          # Path to the LLM model
n_tokens=1600                                          # Number of tokens to use for ICE
batch_size=1                                           # Batch size for inference


## Dataset subset configs
example_bank_size=null             # Number of examples in the example bank
example_bank_segment=null          # Segment index: e.g. with size 1000, 0 uses the first 1000 examples, 1 the next 1000
test_set_size=null                 # Number of questions in the test set
test_set_segment=null              # Segment index: e.g. with size 250, 0 uses the first 250 examples, 1 the next 250
```

#### Parameter Details 

| Parameter              | Description                                                                 |
|------------------------|-----------------------------------------------------------------------------|
| `task_name`            | Dataset to run (e.g., `cmsqa`, `mnli`). Also used for organizing outputs.  |
| `number_of_ice_examples` | Number of in-context examples to use during inference.                      |
| `lambda_val`           | Trade-off parameter for diversity vs. relevance in submodular selection.    |
| `model_name`           | Pre-trained model name for embedding generation (e.g., BERT).               |
| `pred_dir`             | Directory to store inference results.                                       |
| `llm_model_path`       | Local path to the LLM used for ICE-based inference.                         |
| `n_tokens`             | Total number of tokens allowed per prompt (ICE + query).                    |
| `batch_size`           | Number of queries processed together during inference.                      |
| `example_bank_size`    | Total size of the retrieval bank from which ICEs are chosen.                |
| `example_bank_segment` | Segment offset to access different parts of the example bank.               |
| `test_set_size`        | Number of test questions used for evaluation.                               |
| `test_set_segment`     | Segment offset to run a different batch of test examples.                   |
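
To make the size and segment parameters concrete, the sketch below computes which slice of the data a given segment would cover. It assumes segments are zero-based and contiguous, so segment `k` with size `n` covers indices `k*n` up to (but not including) `(k+1)*n`. This reading is inferred from the comments in the example configuration, and the actual implementation may differ. `segment_range` is a hypothetical helper, not part of this repository.

```bash
#!/usr/bin/env bash
# Hypothetical helper (not part of the repository): print the zero-based
# start index and the exclusive end index covered by a segment.
segment_range() {
  local size="$1" segment="$2"
  local start=$(( segment * size ))
  local end=$(( start + size ))
  printf '%s %s\n' "$start" "$end"
}

segment_range 1000 0   # example bank, segment 0 of size 1000: prints "0 1000"
segment_range 250 1    # test set, segment 1 of size 250: prints "250 500"
```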

